{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tutorial 1. Quick Start: ResNet18 on CIFAR10\n",
    "\n",
    "\n",
    "In this tutorial, we will show:\n",
    "\n",
    "- How to train and compress a ResNet18 end-to-end from scratch on CIFAR10.\n",
    "- That the compressed ResNet18 achieves both **high performance** and **significant FLOPs and parameter reductions** compared to the full model.\n",
    "- That the compressed ResNet18 **uses about 91% fewer parameters and 80% fewer FLOPs** while achieving **92.86% accuracy**, only **0.16%** below the baseline.\n",
    "- How to set up the DHSPG optimizer in more detail.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1. Create OTO instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from backends import resnet18_cifar10\n",
    "from only_train_once import OTO\n",
    "\n",
    "model = resnet18_cifar10()\n",
    "dummy_input = torch.zeros(1, 3, 32, 32)\n",
    "oto = OTO(model=model.cuda(), dummy_input=dummy_input.cuda())"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### (Optional) Visualize the dependency graph of the DNN for ZIG partitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A ResNet_zig.gv.pdf will be generated to display the dependency graph.\n",
    "oto.visualize_zigs()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 2. Dataset Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "from torchvision.datasets import CIFAR10\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "trainset = CIFAR10(root='cifar10', train=True, download=True, transform=transforms.Compose([\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.RandomCrop(32, 4),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))\n",
    "testset = CIFAR10(root='cifar10', train=False, download=True, transform=transforms.Compose([\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]))\n",
    "\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=4)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3. Setup DHSPG optimizer\n",
    "\n",
    "The following main hyperparameters need attention:\n",
    "\n",
    "- `variant`: The optimizer used for training the baseline full model.\n",
    "- `lr`: The initial learning rate.\n",
    "- `weight_decay`: Weight decay, as in standard DNN optimization.\n",
    "- `target_group_sparsity`: The target group sparsity. Higher group sparsity typically yields larger FLOPs and model size reductions, but may degrade model performance more.\n",
    "- `start_pruning_steps`: The number of training steps after which pruning starts.\n",
    "- `epsilon`: A coefficient in [0, 1) that controls the aggressiveness of group sparsity exploration. A higher value means more aggressive exploration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = oto.dhspg(\n",
    "    variant='sgd', \n",
    "    lr=0.1, \n",
    "    target_group_sparsity=0.7,\n",
    "    weight_decay=1e-4,\n",
    "    start_pruning_steps=50 * len(trainloader), # start pruning after 50 epochs\n",
    "    # lmbda=1e-2, # enable this argument if the produced group sparsity is unsatisfactory in some corner-case environments.\n",
    "    epsilon=0.95)"
   ]
  },
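  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### (Optional) What `target_group_sparsity` measures\n",
    "\n",
    "To make the notion of group sparsity concrete, the cell below is a minimal sketch that counts the fraction of groups whose weights are all zero. It assumes, purely for illustration, that each row of a weight tensor (for example, one output channel) forms one group. OTO's actual groups are the ZIGs derived from the dependency graph and can span several layers, so the helper `group_sparsity_of` is hypothetical and not part of the library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def group_sparsity_of(weight: torch.Tensor) -> float:\n",
    "    \"\"\"Fraction of rows (groups) that are entirely zero. Hypothetical helper for illustration.\"\"\"\n",
    "    # Assumption: one group per row, e.g. one output channel of a conv or linear layer.\n",
    "    groups = weight.flatten(start_dim=1)\n",
    "    num_zero_groups = (groups.norm(dim=1) == 0).sum().item()\n",
    "    return num_zero_groups / groups.shape[0]\n",
    "\n",
    "w = torch.ones(10, 4)\n",
    "w[:7] = 0.0  # zero out 7 of the 10 groups\n",
    "print(group_sparsity_of(w))  # 7 zero groups out of 10"
   ]
  },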
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 4. Train ResNet18 as normal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, loss: 1.67, omega:4122.27, group_sparsity: 0.00, acc1: 0.3347\n",
      "Epoch: 1, loss: 1.15, omega:4112.58, group_sparsity: 0.00, acc1: 0.4246\n",
      "Epoch: 2, loss: 0.88, omega:4103.96, group_sparsity: 0.00, acc1: 0.5233\n",
      "Epoch: 3, loss: 0.71, omega:4094.89, group_sparsity: 0.00, acc1: 0.6019\n",
      "Epoch: 4, loss: 0.60, omega:4084.36, group_sparsity: 0.00, acc1: 0.7093\n",
      "Epoch: 5, loss: 0.52, omega:4073.36, group_sparsity: 0.00, acc1: 0.7460\n",
      "Epoch: 6, loss: 0.47, omega:4062.06, group_sparsity: 0.00, acc1: 0.7607\n",
      "...\n",
      "Epoch: 295, loss: 0.01, omega:1099.61, group_sparsity: 0.70, acc1: 0.9265\n",
      "Epoch: 296, loss: 0.01, omega:1099.60, group_sparsity: 0.70, acc1: 0.9276\n",
      "Epoch: 297, loss: 0.01, omega:1099.60, group_sparsity: 0.70, acc1: 0.9267\n",
      "Epoch: 298, loss: 0.01, omega:1099.59, group_sparsity: 0.70, acc1: 0.9278\n",
      "Epoch: 299, loss: 0.01, omega:1099.59, group_sparsity: 0.70, acc1: 0.9286\n"
     ]
    }
   ],
   "source": [
    "from utils.utils import check_accuracy\n",
    "\n",
    "max_epoch = 300\n",
    "model.cuda()\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "# Every 75 epochs, divide the learning rate by 10\n",
    "lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)\n",
    "\n",
    "for epoch in range(max_epoch):\n",
    "    f_avg_val = 0.0\n",
    "    model.train()\n",
    "    for X, y in trainloader:\n",
    "        X = X.cuda()\n",
    "        y = y.cuda()\n",
    "        y_pred = model(X)\n",
    "        f = criterion(y_pred, y)\n",
    "        optimizer.zero_grad()\n",
    "        f.backward()\n",
    "        # Accumulate a Python float instead of a tensor attached to the autograd graph\n",
    "        f_avg_val += f.item()\n",
    "        optimizer.step()\n",
    "    # Step the scheduler after the epoch's optimizer updates (PyTorch >= 1.1 convention)\n",
    "    lr_scheduler.step()\n",
    "    group_sparsity, omega = optimizer.compute_group_sparsity_omega()\n",
    "    accuracy1, accuracy5 = check_accuracy(model, testloader)\n",
    "    f_avg_val = f_avg_val / len(trainloader)\n",
    "    print(\"Epoch: {ep}, loss: {f:.2f}, omega:{om:.2f}, group_sparsity: {gs:.2f}, acc1: {acc:.4f}\".format(ep=epoch, f=f_avg_val, om=omega, gs=group_sparsity, acc=accuracy1))"
   ]
  },
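  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### (Optional) Inspect the learning rate schedule\n",
    "\n",
    "The step schedule used in training can be written in closed form as `lr * gamma ** (epoch // step_size)`. The cell below is a small sketch that evaluates this formula for a few epochs with the values from this tutorial (`lr=0.1`, `step_size=75`, `gamma=0.1`). It assumes the scheduler is stepped exactly once after each completed epoch; it does not touch the optimizer or the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_lr, step_size, gamma = 0.1, 75, 0.1\n",
    "\n",
    "# Closed form of a step decay: the rate is multiplied by gamma once per full step_size epochs.\n",
    "for epoch in (0, 74, 75, 150, 225, 299):\n",
    "    lr = base_lr * gamma ** (epoch // step_size)\n",
    "    print(f\"epoch {epoch:3d}: lr = {lr:.0e}\")"
   ]
  },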
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 5. Get compressed model in ONNX format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A ResNet_compressed.onnx will be generated. \n",
    "oto.compress()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (Optional) Compute FLOPs and number of parameters before and after OTO training\n",
    "\n",
    "The compressed ResNet18 only uses 8.8% of the parameters and 20.3% of the FLOPs compared to the full model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full FLOPs (M): 555.42. Compressed FLOPs (M): 112.52. Reduction Ratio: 0.7974\n",
      "Full # Params: 11173962. Compressed # Params: 981808. Reduction Ratio: 0.9121\n"
     ]
    }
   ],
   "source": [
    "full_flops = oto.compute_flops()\n",
    "compressed_flops = oto.compute_flops(compressed=True)\n",
    "full_num_params = oto.compute_num_params()\n",
    "compressed_num_params = oto.compute_num_params(compressed=True)\n",
    "\n",
    "print(\"Full FLOPs (M): {f_flops:.2f}. Compressed FLOPs (M): {c_flops:.2f}. Reduction Ratio: {f_ratio:.4f}\"\\\n",
    "      .format(f_flops=full_flops, c_flops=compressed_flops, f_ratio=1 - compressed_flops/full_flops))\n",
    "print(\"Full # Params: {f_params}. Compressed # Params: {c_params}. Reduction Ratio: {f_ratio:.4f}\"\\\n",
    "      .format(f_params=full_num_params, c_params=compressed_num_params, f_ratio=1 - compressed_num_params/full_num_params))"
   ]
  },
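  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### (Optional) Sanity-check the reduction ratios\n",
    "\n",
    "The percentages quoted in this section follow directly from the counts printed above. As a quick check, the cell below recomputes the kept fraction and the reduction ratio from those printed numbers, which are hard-coded here so that the cell runs without a trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values copied from the printed output of the previous cell.\n",
    "reported_full_flops_m, reported_compressed_flops_m = 555.42, 112.52\n",
    "reported_full_params, reported_compressed_params = 11173962, 981808\n",
    "\n",
    "flops_kept = reported_compressed_flops_m / reported_full_flops_m\n",
    "params_kept = reported_compressed_params / reported_full_params\n",
    "\n",
    "print(f\"FLOPs kept: {flops_kept:.1%}, reduction ratio: {1 - flops_kept:.4f}\")\n",
    "print(f\"Params kept: {params_kept:.1%}, reduction ratio: {1 - params_kept:.4f}\")"
   ]
  },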
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (Optional) Check the compressed model accuracy\n",
    "\n",
    "#### Both the full and compressed models should return exactly the same accuracy.\n",
    "\n",
    "Compared to the [baseline of full ResNet18 on CIFAR10 (93.02% top-1 accuracy)](https://github.com/kuangliu/pytorch-cifar), the compressed ResNet18 loses only 0.16% top-1 accuracy while significantly reducing FLOPs and parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full model: Acc 1: 0.9286, Acc 5: 0.9974\n",
      "Compressed model: Acc 1: 0.9286, Acc 5: 0.9974\n"
     ]
    }
   ],
   "source": [
    "from utils.utils import check_accuracy_onnx\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=4)\n",
    "\n",
    "acc1_full, acc5_full = check_accuracy(model, testloader)\n",
    "print(\"Full model: Acc 1: {acc1}, Acc 5: {acc5}\".format(acc1=acc1_full, acc5=acc5_full))\n",
    "\n",
    "acc1_compressed, acc5_compressed = check_accuracy_onnx(oto.compressed_model_path, testloader)\n",
    "print(\"Compressed model: Acc 1: {acc1}, Acc 5: {acc5}\".format(acc1=acc1_compressed, acc5=acc5_compressed))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
